{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "pBKdH1A6EKoH"
      },
      "source": [
        "# Demo code for \"[Similarity of Neural Network Representations Revisited](https://arxiv.org/abs/1905.00414)\"\n",
        "\n",
        "Copyright 2019 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License.\n",
        "\n",
        "Please cite as:\n",
        "\n",
        "    @inproceedings{pmlr-v97-kornblith19a,\n",
        "      title = {Similarity of Neural Network Representations Revisited},\n",
        "      author = {Kornblith, Simon and Norouzi, Mohammad and Lee, Honglak and Hinton, Geoffrey},\n",
        "      booktitle = {Proceedings of the 36th International Conference on Machine Learning},\n",
        "      pages = {3519--3529},\n",
        "      year = {2019},\n",
        "      volume = {97},\n",
        "      month = {09--15 Jun},\n",
        "      publisher = {PMLR}\n",
        "    }"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "MkucRi3yn7UJ"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def gram_linear(x):\n",
        "  \"\"\"Compute Gram (kernel) matrix for a linear kernel.\n",
        "\n",
        "  Args:\n",
        "    x: A num_examples x num_features matrix of features.\n",
        "\n",
        "  Returns:\n",
        "    A num_examples x num_examples Gram matrix of examples.\n",
        "  \"\"\"\n",
        "  return x.dot(x.T)\n",
        "\n",
        "\n",
        "def gram_rbf(x, threshold=1.0):\n",
        "  \"\"\"Compute Gram (kernel) matrix for an RBF kernel.\n",
        "\n",
        "  Args:\n",
        "    x: A num_examples x num_features matrix of features.\n",
        "    threshold: Fraction of median Euclidean distance to use as RBF kernel\n",
        "      bandwidth. (This is the heuristic we use in the paper. There are other\n",
        "      possible ways to set the bandwidth; we didn't try them.)\n",
        "\n",
        "  Returns:\n",
        "    A num_examples x num_examples Gram matrix of examples.\n",
        "  \"\"\"\n",
        "  dot_products = x.dot(x.T)\n",
        "  sq_norms = np.diag(dot_products)\n",
        "  sq_distances = -2 * dot_products + sq_norms[:, None] + sq_norms[None, :]\n",
        "  sq_median_distance = np.median(sq_distances)\n",
        "  return np.exp(-sq_distances / (2 * threshold ** 2 * sq_median_distance))\n",
        "\n",
        "\n",
        "def center_gram(gram, unbiased=False):\n",
        "  \"\"\"Center a symmetric Gram matrix.\n",
        "\n",
        "  This is equvialent to centering the (possibly infinite-dimensional) features\n",
        "  induced by the kernel before computing the Gram matrix.\n",
        "\n",
        "  Args:\n",
        "    gram: A num_examples x num_examples symmetric matrix.\n",
        "    unbiased: Whether to adjust the Gram matrix in order to compute an unbiased\n",
        "      estimate of HSIC. Note that this estimator may be negative.\n",
        "\n",
        "  Returns:\n",
        "    A symmetric matrix with centered columns and rows.\n",
        "  \"\"\"\n",
        "  if not np.allclose(gram, gram.T):\n",
        "    raise ValueError('Input must be a symmetric matrix.')\n",
        "  gram = gram.copy()\n",
        "\n",
        "  if unbiased:\n",
        "    # This formulation of the U-statistic, from Szekely, G. J., \u0026 Rizzo, M.\n",
        "    # L. (2014). Partial distance correlation with methods for dissimilarities.\n",
        "    # The Annals of Statistics, 42(6), 2382-2412, seems to be more numerically\n",
        "    # stable than the alternative from Song et al. (2007).\n",
        "    n = gram.shape[0]\n",
        "    np.fill_diagonal(gram, 0)\n",
        "    means = np.sum(gram, 0, dtype=np.float64) / (n - 2)\n",
        "    means -= np.sum(means) / (2 * (n - 1))\n",
        "    gram -= means[:, None]\n",
        "    gram -= means[None, :]\n",
        "    np.fill_diagonal(gram, 0)\n",
        "  else:\n",
        "    means = np.mean(gram, 0, dtype=np.float64)\n",
        "    means -= np.mean(means) / 2\n",
        "    gram -= means[:, None]\n",
        "    gram -= means[None, :]\n",
        "\n",
        "  return gram\n",
        "\n",
        "\n",
        "def cka(gram_x, gram_y, debiased=False):\n",
        "  \"\"\"Compute CKA.\n",
        "\n",
        "  Args:\n",
        "    gram_x: A num_examples x num_examples Gram matrix.\n",
        "    gram_y: A num_examples x num_examples Gram matrix.\n",
        "    debiased: Use unbiased estimator of HSIC. CKA may still be biased.\n",
        "\n",
        "  Returns:\n",
        "    The value of CKA between X and Y.\n",
        "  \"\"\"\n",
        "  gram_x = center_gram(gram_x, unbiased=debiased)\n",
        "  gram_y = center_gram(gram_y, unbiased=debiased)\n",
        "\n",
        "  # Note: To obtain HSIC, this should be divided by (n-1)**2 (biased variant) or\n",
        "  # n*(n-3) (unbiased variant), but this cancels for CKA.\n",
        "  scaled_hsic = gram_x.ravel().dot(gram_y.ravel())\n",
        "\n",
        "  normalization_x = np.linalg.norm(gram_x)\n",
        "  normalization_y = np.linalg.norm(gram_y)\n",
        "  return scaled_hsic / (normalization_x * normalization_y)\n",
        "\n",
        "\n",
        "def _debiased_dot_product_similarity_helper(\n",
        "    xty, sum_squared_rows_x, sum_squared_rows_y, squared_norm_x, squared_norm_y,\n",
        "    n):\n",
        "  \"\"\"Helper for computing debiased dot product similarity (i.e. linear HSIC).\"\"\"\n",
        "  # This formula can be derived by manipulating the unbiased estimator from\n",
        "  # Song et al. (2007).\n",
        "  return (\n",
        "      xty - n / (n - 2.) * sum_squared_rows_x.dot(sum_squared_rows_y)\n",
        "      + squared_norm_x * squared_norm_y / ((n - 1) * (n - 2)))\n",
        "\n",
        "\n",
        "def feature_space_linear_cka(features_x, features_y, debiased=False):\n",
        "  \"\"\"Compute CKA with a linear kernel, in feature space.\n",
        "\n",
        "  This is typically faster than computing the Gram matrix when there are fewer\n",
        "  features than examples.\n",
        "\n",
        "  Args:\n",
        "    features_x: A num_examples x num_features matrix of features.\n",
        "    features_y: A num_examples x num_features matrix of features.\n",
        "    debiased: Use unbiased estimator of dot product similarity. CKA may still be\n",
        "      biased. Note that this estimator may be negative.\n",
        "\n",
        "  Returns:\n",
        "    The value of CKA between X and Y.\n",
        "  \"\"\"\n",
        "  features_x = features_x - np.mean(features_x, 0, keepdims=True)\n",
        "  features_y = features_y - np.mean(features_y, 0, keepdims=True)\n",
        "\n",
        "  dot_product_similarity = np.linalg.norm(features_x.T.dot(features_y)) ** 2\n",
        "  normalization_x = np.linalg.norm(features_x.T.dot(features_x))\n",
        "  normalization_y = np.linalg.norm(features_y.T.dot(features_y))\n",
        "\n",
        "  if debiased:\n",
        "    n = features_x.shape[0]\n",
        "    # Equivalent to np.sum(features_x ** 2, 1) but avoids an intermediate array.\n",
        "    sum_squared_rows_x = np.einsum('ij,ij-\u003ei', features_x, features_x)\n",
        "    sum_squared_rows_y = np.einsum('ij,ij-\u003ei', features_y, features_y)\n",
        "    squared_norm_x = np.sum(sum_squared_rows_x)\n",
        "    squared_norm_y = np.sum(sum_squared_rows_y)\n",
        "\n",
        "    dot_product_similarity = _debiased_dot_product_similarity_helper(\n",
        "        dot_product_similarity, sum_squared_rows_x, sum_squared_rows_y,\n",
        "        squared_norm_x, squared_norm_y, n)\n",
        "    normalization_x = np.sqrt(_debiased_dot_product_similarity_helper(\n",
        "        normalization_x ** 2, sum_squared_rows_x, sum_squared_rows_x,\n",
        "        squared_norm_x, squared_norm_x, n))\n",
        "    normalization_y = np.sqrt(_debiased_dot_product_similarity_helper(\n",
        "        normalization_y ** 2, sum_squared_rows_y, sum_squared_rows_y,\n",
        "        squared_norm_y, squared_norm_y, n))\n",
        "\n",
        "  return dot_product_similarity / (normalization_x * normalization_y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "WP8XuJkmEXLe"
      },
      "source": [
        "## Tutorial"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZJv-cq-j-a9Y"
      },
      "source": [
        "First, we generate some random data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "zeI3il6JRM-L"
      },
      "outputs": [],
      "source": [
        "np.random.seed(1337)\n",
        "X = np.random.randn(100, 10)\n",
        "Y = np.random.randn(100, 10) + X"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "obJCdw02-dcn"
      },
      "source": [
        "Linear CKA can be computed either based on dot products between examples or dot products between features:\n",
        "$$\\langle\\text{vec}(XX^\\text{T}),\\text{vec}(YY^\\text{T})\\rangle = ||Y^\\text{T}X||_\\text{F}^2$$\n",
        "The formulation based on similarities between features (right-hand side) is faster than the formulation based on similarities between similarities between examples (left-hand side) when the number of examples exceeds the number of features. We provide both formulations here and demonstrate that they are equvialent."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "height": 51
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 256,
          "status": "ok",
          "timestamp": 1559866109801,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "45qb6zdSsHj6",
        "outputId": "47b55b33-d54c-4c6d-80db-def0f4a91269"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Linear CKA from Examples: 0.55761\n",
            "Linear CKA from Features: 0.55761\n"
          ]
        }
      ],
      "source": [
        "cka_from_examples = cka(gram_linear(X), gram_linear(Y))\n",
        "cka_from_features = feature_space_linear_cka(X, Y)\n",
        "\n",
        "print('Linear CKA from Examples: {:.5f}'.format(cka_from_examples))\n",
        "print('Linear CKA from Features: {:.5f}'.format(cka_from_features))\n",
        "np.testing.assert_almost_equal(cka_from_examples, cka_from_features)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "3R4L1QH5_JEd"
      },
      "source": [
        "It is also possible to compute CKA with nonlinear kernels. Here, we use an RBF kernel with the bandwidth set to $\\frac{1}{2}$ the median distance in the distance matrix."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "height": 34
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 300,
          "status": "ok",
          "timestamp": 1559866110174,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "2iHcuAMQs7RM",
        "outputId": "5fdb64bc-8f55-4edf-d5ce-0ecbe32d70fd"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "RBF CKA: 0.65483\n"
          ]
        }
      ],
      "source": [
        "rbf_cka = cka(gram_rbf(X, 0.5), gram_rbf(Y, 0.5))\n",
        "print('RBF CKA: {:.5f}'.format(rbf_cka))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "VtXdefE-ARRM"
      },
      "source": [
        "If the number of examples is small, it might help to compute a \"debiased\" form of CKA. This form of CKA can be obtained by recognizing that HSIC can be formulated as a U-statistic, as in [Song et al., 2007](https://arxiv.org/pdf/0704.2668.pdf), and replacing biased estimators of HSIC in the numerator and denominator with this unbiased estimator. With some further algebraic manipulation, we also derived an unbiased estimator based on similarities between features rather than the Gram matrices. The resulting estimator of CKA is still generally biased, but the bias is reduced."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "colab": {
          "height": 51
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 289,
          "status": "ok",
          "timestamp": 1559866110554,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "jInZLKpx3fV8",
        "outputId": "31d3d2e1-f62b-40b0-f3fa-bc89cb9d7eb8"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Linear CKA from Examples (Debiased): 0.51346\n",
            "Linear CKA from Features (Debiased): 0.51346\n"
          ]
        }
      ],
      "source": [
        "cka_from_examples_debiased = cka(gram_linear(X), gram_linear(Y), debiased=True)\n",
        "cka_from_features_debiased = feature_space_linear_cka(X, Y, debiased=True)\n",
        "\n",
        "print('Linear CKA from Examples (Debiased): {:.5f}'.format(\n",
        "    cka_from_examples_debiased))\n",
        "print('Linear CKA from Features (Debiased): {:.5f}'.format(\n",
        "    cka_from_features_debiased))\n",
        "\n",
        "np.testing.assert_almost_equal(cka_from_examples_debiased,\n",
        "                               cka_from_features_debiased)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "QZNV8MGQTaSp"
      },
      "source": [
        "## CKA vs. CCA\n",
        "\n",
        "Below, we show how to compute the mean squared CCA correlation $R^2_\\text{CCA}$."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "height": 34
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 329,
          "status": "ok",
          "timestamp": 1559866110947,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "3C3oF6vrECHt",
        "outputId": "11469eb9-47e1-4032-96b8-994ba6f56770"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Mean Squared CCA Correlation: 0.54312\n"
          ]
        }
      ],
      "source": [
        "def cca(features_x, features_y):\n",
        "  \"\"\"Compute the mean squared CCA correlation (R^2_{CCA}).\n",
        "\n",
        "  Args:\n",
        "    features_x: A num_examples x num_features matrix of features.\n",
        "    features_y: A num_examples x num_features matrix of features.\n",
        "\n",
        "  Returns:\n",
        "    The mean squared CCA correlations between X and Y.\n",
        "  \"\"\"\n",
        "  qx, _ = np.linalg.qr(features_x)  # Or use SVD with full_matrices=False.\n",
        "  qy, _ = np.linalg.qr(features_y)\n",
        "  return np.linalg.norm(qx.T.dot(qy)) ** 2 / min(\n",
        "      features_x.shape[1], features_y.shape[1])\n",
        "\n",
        "print('Mean Squared CCA Correlation: {:.5f}'.format(cca(X, Y)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wLQIIRhiGrQy"
      },
      "source": [
        "## Invariance Properties\n",
        "\n",
        "Finally, we verify the invariance properties of CKA and compare them to CCA."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "TGrluBcJGtbn"
      },
      "outputs": [],
      "source": [
        "transform = np.random.randn(10, 10)\n",
        "_, orthogonal_transform = np.linalg.eigh(transform.T.dot(transform))\n",
        "\n",
        "# CKA is invariant only to orthogonal transformations.\n",
        "np.testing.assert_almost_equal(\n",
        "    feature_space_linear_cka(X, Y),\n",
        "    feature_space_linear_cka(X.dot(orthogonal_transform), Y))\n",
        "np.testing.assert_(not np.isclose(\n",
        "    feature_space_linear_cka(X, Y),\n",
        "    feature_space_linear_cka(X.dot(transform), Y)))\n",
        "\n",
        "# CCA is invariant to any invertible linear transform.\n",
        "np.testing.assert_almost_equal(cca(X, Y), cca(X.dot(orthogonal_transform), Y))\n",
        "np.testing.assert_almost_equal(cca(X, Y), cca(X.dot(transform), Y))\n",
        "\n",
        "# Both CCA and CKA are invariant to isotropic scaling.\n",
        "np.testing.assert_almost_equal(cca(X, Y), cca(X * 1.337, Y))\n",
        "np.testing.assert_almost_equal(\n",
        "    feature_space_linear_cka(X, Y),\n",
        "    feature_space_linear_cka(X * 1.337, Y))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "Similarity of Neural Network Representations Revisited Demo.ipynb",
      "provenance": [],
      "version": "0.3.2"
    },
    "kernelspec": {
      "display_name": "Python 2",
      "name": "python2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
